import sys

import torch

sys.path.append("src")
from dataset.dataset import get_loaders
from pruner.utils import find_layers


def llmpruner_1_importance(
    model,
    tokenizer,
    n_samples,
    seed,
    device=torch.device("cuda:0"),
    pretrained_dataset_name="c4",
):
    """
    Compute element-wise first-order (gradient x weight) importance scores.

    Gradients of the model's loss are accumulated over the calibration
    batches returned by ``get_loaders`` and divided by ``n_samples``. Then,
    for every sub-module that ``find_layers`` returns inside each transformer
    layer, the score ``|W.grad * W|`` is computed element-wise.

    Args:
        model: Model exposing ``model.seqlen`` and either
            ``model.model.layers`` or ``model.model.decoder.layers``.
        tokenizer: Tokenizer forwarded to ``get_loaders``.
        n_samples: Number of calibration samples requested from
            ``get_loaders``. It is also the divisor used to average the
            accumulated gradients.
        seed: Random seed forwarded to ``get_loaders``.
        device: Device every batch tensor is moved to before the forward pass.
        pretrained_dataset_name: Dataset name forwarded to ``get_loaders``.

    Returns:
        dict[int, torch.Tensor]: Importance tensors keyed by a running index.
        The index walks the layers in order and, within each layer, the
        sub-modules in the iteration order of ``find_layers(layer)``.

    Side effects:
        Puts the model in training mode via ``model.train()``. It stays in
        training mode, and the averaged gradients stay in ``param.grad``.
    """
    import tqdm

    W_metrics = {}

    pretrained_dataloader, _ = get_loaders(
        name=pretrained_dataset_name,
        nsamples=n_samples,
        seed=seed,
        seqlen=model.seqlen,
        tokenizer=tokenizer,
    )

    try:
        layers = model.model.layers
    except AttributeError:
        # Fall back to models whose blocks live under ``model.model.decoder``.
        layers = model.model.decoder.layers

    # Kept from the original implementation.
    # NOTE(review): train() enables dropout (if the model has any) while the
    # gradients are accumulated - confirm this is intended.
    model.train()

    # Accumulate gradients of the loss over all calibration batches.
    model.zero_grad()
    for batch in tqdm.tqdm(
        pretrained_dataloader, desc="pretrained gradients calculation"
    ):
        batch = tuple(t.to(device) for t in batch)
        # The input ids double as labels. This presumably relies on a
        # HF-style model that computes (and shifts for) the LM loss itself.
        inputs = {"input_ids": batch[0], "labels": batch[0]}
        outputs = model(**inputs)
        loss = outputs[0]
        loss.backward()

    # Average the accumulated gradients. Frozen or unused parameters never
    # receive a gradient (``grad is None``), so skip them instead of failing
    # on ``None / n_samples``.
    for param in model.parameters():
        if param.grad is not None:
            param.grad = param.grad / n_samples

    # Store |grad * weight| for every sub-module, numbered sequentially
    # across layers.
    cnt = 0
    for layer in layers:
        subset = find_layers(layer)
        for name in subset:
            weight = subset[name].weight
            W_metrics[cnt] = torch.abs(weight.grad.detach() * weight.data)
            cnt += 1

    torch.cuda.empty_cache()
    return W_metrics
